import os
import cv2
import time
import numpy as np

imageDir = os.path.abspath("images")
filenames = os.listdir(imageDir)

# 常量类
class Const(object):
	MAX_LEVEL = 8


# 颜色类
class Color(object):

	# Color构造函数
	def __init__(self, r, g, b):
		self.r = r
		self.g = g
		self.b = b

	# 颜色相加
	def add(self, color):
		self.r += color.r
		self.g += color.g
		self.b += color.b

	# 颜色除法
	def div(self, k):
		if k == 0:
			raise Exception("error color div zero.")
		return Color(self.r // k, self.g // k, self.b // k)

	# 根据八叉树原理得到是哪个children
	def getIndex(self, level):
		r = "{0:08b}".format(self.r)[level]
		g = "{0:08b}".format(self.g)[level]
		b = "{0:08b}".format(self.b)[level]
		return int(''.join([r, g, b]), 2)

	def __str__(self):
		return "Color({0}, {1}, {2})".format(self.r, self.g, self.b)

	def __repr__(self):
		return str(self)


# 八叉树节点类
class Node(object):

	# 节点构造函数
	def __init__(self, level, parent):
		self.color = Color(0, 0, 0)						# 节点颜色
		self.level = level  							# 节点所属level
		self.children = [None for i in range(8)]		# 节点拥有的children
		self.pixedCount = 0  							# 相同颜色的个数
		if level < Const.MAX_LEVEL - 1:					# 节点level为7, 则不进octree levels链表
			parent.addLevelNode(level + 1, self)		# 由于root节点level为-1, 因此进行+1操作

	# 递归创建一个Color路径到叶节点
	def addColor(self, color, level, parent):
		if level < Const.MAX_LEVEL:										# level小于8
			index = color.getIndex(level)								# level从0到7, 也就是获取每一列的rgb编码进而得到的索引
			if self.children[index] is None:							# 对应index孩子还没创建
				self.children[index] = Node(level, parent)				# 创建孩子节点, 放在index位置
			self.children[index].addColor(color, level + 1, parent)		# 递归level到下一层
		else:															# level到达第8层, 为叶节点
			self.color.add(color)										# 第8层不需要创建Node, 直接累加第7层的color
			self.pixedCount += 1  										# 累加第7层的color数量

	# 获取包括自身及孩子的叶节点
	def leafNodes(self):
		leafNodes = []											# 记录叶节点
		if self.isLeaf():
			leafNodes.append(self)								# append叶节点
		else:
			for node in self.children:							# 遍历子树
				if not node is None:
					leafNodes = leafNodes + node.leafNodes()	# 拿到子树的叶节点
		return leafNodes										# 返回叶节点

	# 是否为叶节点
	def isLeaf(self):
		return self.pixedCount > 0 		# 若Node的PixedCound大于0, 说明是叶节点

	# 合并孩子节点(该操作在外部调用需要从level=7开始, 也就是从叶子节点一直遍历到root节点; 在这个过程中, 节点A的叶子(孩子)节点被合并, 节点A变为了叶节点)
	def reduce(self):
		reduceCount = 0  								# 合并叶节点数量
		for node in self.children:						# 遍历孩子
			if not node is None:
				self.color.add(node.color)				# 将叶节点颜色值累加到父节点上
				self.pixedCount += node.pixedCount		# 将叶节点像素个数累加到父节点上
				reduceCount += 1 						# 合并计数
		self.children = [None for i in range(8)]  		# 将孩子抛弃掉
		return reduceCount - 1  						# 由于自身Node变为叶节点, 因此+1

	# 均值Node上的Color值
	def normalize(self):
		return self.color.div(self.pixedCount)			# 均值Color


# 八叉树算法
class Octree(object):

	# 八叉树构造函数
	def __init__(self):
		self.levels = [[] for i in range(Const.MAX_LEVEL)] 		# 构建levels链表, 用于后续提取color主题色
		self.root = Node(-1, self)								# root节点

	# 添加Node到levels链表中
	def addLevelNode(self, level, node):
		self.levels[level].append(node)

	# 添加颜色到八叉树中
	def addColor(self, color):
		self.root.addColor(color, 0, self)

	# 提取颜色
	def extractColor(self, k = 256):
		leafCount = len(self.root.leafNodes())		# 获取八叉树叶节点数量
		for i in range(Const.MAX_LEVEL, 0, -1):		# 从level为7开始遍历, 合并叶节点
			level = i - 1   						# 由于i从8开始, 因此-1
			if leafCount <= k:						# 如果叶节点已经小于提取个数k, 结束
				break
			if not self.levels[level] is None:		# 链表不为空, 遍历叶节点的父亲层, 因为在这里是要统计父节点的颜色
				for node in self.levels[level]:		# 遍历链表中的node节点
					leafCount -= node.reduce()		# 统计node的颜色
					if leafCount <= k:				# 如果叶节点已经小于提取个数k, 结束
						break
			self.levels[level] = []					# level置空
		# 提取色
		colors = []
		# 获取合并后八叉树的所有叶节点
		for leafNode in self.root.leafNodes():
			if leafNode.isLeaf() and len(colors) <= k:
				# 获取Node颜色的均值
				colors.append(leafNode.normalize())
		# 返回提取色
		return colors


def octreeColor(imgPath):
	# 读取图片
	img = cv2.imdecode(np.fromfile(imgPath, dtype=np.uint8), -1)
	img = np.array(img)
	width, height, channel = img.shape
	# 创建八叉树
	octree = Octree()
	# 加入颜色
	for i in range(width):
		for j in range(height):
			octree.addColor(Color(img[i, j, 0], img[i, j, 1], img[i, j, 2]))
	# 提取颜色
	k = 16
	colors = octree.extractColor(k)
	# 显示色块
	showColors = np.zeros(shape=(len(colors) * 20, 200, 3), dtype=np.uint8)
	for i in range(len(colors)):
		c = [colors[i].r, colors[i].g, colors[i].b]
		showColors[i*20: i*20+20, :, :] = np.array(c)
	cv2.imshow("", showColors)
	cv2.waitKey(0)

def main():
	octreeColor(os.path.join(imageDir, filenames[0]))

if "__main__" == __name__:
	main()
	exit(0)